fix(mlx): max_grad_value default off, honor user max_grad_norm #663

Closed
danielhanchen wants to merge 8 commits into main from fix-mlx-grad-clip-hf-parity

Conversation

@danielhanchen

Member

Summary

  • MLX trainer regression from #634 ("fix mlx: Adds the MLX training path used by Studio on Apple Silicon"): the MLXTrainingConfig.max_grad_value default of 1.0 (later 5.0) silently zeroes a user-supplied max_grad_norm, breaking HF/TRL parity. The same hyperparameters that converge under trl.SFTTrainer on CUDA produce gibberish on MLX.
  • Make elementwise value clipping opt-in only. Default max_grad_value=None. When None or 0, the user's max_grad_norm is honored.
  • Existing override notice still fires when both max_grad_norm > 0 and max_grad_value > 0 are passed explicitly.
  • Adds a config-level regression test pinning the default to None.

Bisection, CUDA mirror evidence, and recommended fix are in issue #662.

Test plan

  • pytest tests/test_pr_a_deep_components.py (22 passed, including new test)
  • downstream: unsloth/tests/studio/run_real_mlx_smoke.py greens on Mac M1 CI without max_grad_value=0 workaround

PR #634 set MLXTrainingConfig.max_grad_value = 1.0 (later 5.0) and at
config-resolution time silently zeroed out a user-supplied
max_grad_norm. The elementwise clip rotates the gradient per leaf,
which is mathematically different from clip_grad_norm and not what
HF/TRL users opt into when they pass max_grad_norm=1.0. Same dataset,
same seed, same LR converges to a different basin on MLX than on CUDA;
greedy generation collapses to gibberish even though loss descends.
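
To illustrate why the two clips are not interchangeable, here is a minimal pure-Python sketch (no MLX dependency). The helper names and the sample gradient are made up for illustration and are not taken from the trainer. Norm clipping rescales the whole vector and keeps its direction; value clipping clamps each element independently and changes the direction.

```python
import math

def clip_by_global_norm(grad, max_norm):
    """Scale the whole vector so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, max_norm / (norm + 1e-6))
    return [g * scale for g in grad]

def clip_by_value(grad, max_value):
    """Clamp each element independently to [-max_value, max_value]."""
    return [max(-max_value, min(max_value, g)) for g in grad]

def cosine(a, b):
    """Cosine similarity between two vectors (1.0 means same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

grad = [4.0, 0.5, -0.25]                  # hypothetical gradient with one large element
by_norm = clip_by_global_norm(grad, 1.0)  # same direction, shorter vector
by_value = clip_by_value(grad, 1.0)       # only the large element is clamped

print(cosine(grad, by_norm))   # direction preserved
print(cosine(grad, by_value))  # direction changed
```

With this input the value-clipped vector is [1.0, 0.5, -0.25], so the dominant component loses most of its weight relative to the others.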

Make max_grad_value opt-in only:
  * Default None (off). User-supplied max_grad_norm is honored by
    default, matching HF/TRL semantics on CUDA.
  * Explicit float > 0 keeps the existing low-memory clip path AND the
    existing "ignoring max_grad_norm" notice when both are set.

Add a config-level regression test pinning the default to None.
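
A minimal sketch of the opt-in semantics described above, assuming a stand-in dataclass rather than the real MLXTrainingConfig. The class name, the max_grad_norm default of 0.0, the resolve_clipping helper, and the notice wording are all hypothetical.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class ClipConfigSketch:
    # Hypothetical stand-in for MLXTrainingConfig; only the two clip knobs are modeled.
    max_grad_norm: float = 0.0              # assumed default, not taken from the PR
    max_grad_value: Optional[float] = None  # off by default (opt-in)

def resolve_clipping(cfg) -> Tuple[float, float, Optional[str]]:
    """Return (max_grad_norm, max_grad_value, notice) after resolution."""
    norm = float(cfg.max_grad_norm or 0.0)
    value = float(cfg.max_grad_value or 0.0)  # None and 0 both mean "off"
    notice = None
    if norm > 0 and value > 0:
        # Both set explicitly: keep the value clip and tell the user.
        notice = "max_grad_value is set; ignoring max_grad_norm"
        norm = 0.0
    return norm, value, notice

def test_default_max_grad_value_is_none():
    # Config-level regression test pinning the default.
    assert ClipConfigSketch().max_grad_value is None
```

With the default left at None, `resolve_clipping(ClipConfigSketch(max_grad_norm=1.0))` returns the user's norm untouched and no notice.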

Refs: #662, unslothai/unsloth#5498.
danielhanchen added a commit to unslothai/unsloth that referenced this pull request May 17, 2026
The MLX trainer's silent override of max_grad_norm by max_grad_value
is being fixed upstream in unsloth-zoo #663. Once that lands, the
smoke test's max_grad_norm=1.0 is the only clip in effect by default,
matching trl.SFTTrainer on CUDA, and the EXPECT_IN_OUTPUT
assertion becomes a proper HF/CUDA parity gate. Add a comment that
explains what the assertion is really protecting.

Refs: unslothai/unsloth-zoo#662, unslothai/unsloth-zoo#663.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request updates the MLXTrainingConfig to set the default max_grad_value to None, making the elementwise clipping path opt-in. This change ensures that user-provided max_grad_norm values are respected by default, aligning with Hugging Face and TRL semantics. Corresponding logic in the trainer was updated to handle the None default, and a new test case was added to verify the fix. I have no feedback to provide.

PR #634 silently flipped MLX AdamW's bias_correction from the historical
MLX default of False to True (matching torch.optim.AdamW). For real
multi-epoch fine-tunes the two converge identically after ~10-20
warmup steps, but for short memorization runs the difference is large:
bias_correction=True shrinks the step-1 effective update by ~3x.
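
The ~3x figure can be checked with textbook Adam arithmetic. The sketch below assumes the common betas (0.9, 0.999), which are not stated in this thread, and a scalar gradient; it is a worked example, not the MLX implementation.

```python
import math

def adam_first_step(grad, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
                    bias_correction=True):
    """Size of the very first Adam update for a scalar gradient."""
    m = (1 - beta1) * grad         # first moment after one step (starts at 0)
    v = (1 - beta2) * grad * grad  # second moment after one step (starts at 0)
    if bias_correction:
        m /= (1 - beta1)           # undo the zero-init bias at t = 1
        v /= (1 - beta2)
    return lr * m / (math.sqrt(v) + eps)

with_bc = adam_first_step(0.5, bias_correction=True)
without_bc = adam_first_step(0.5, bias_correction=False)
ratio = without_bc / with_bc       # how much larger the uncorrected step is
print(ratio)
```

At step 1 the ratio is (1 - beta1) / sqrt(1 - beta2) = 0.1 / sqrt(0.001), about 3.16, independent of the gradient's magnitude (ignoring eps).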

Empirical bisection on a Mac M1 CI runner (probes 12 + 14 of the
mlx-parity-probes workflow):
  * pre-#634 trainer (bias_correction=False), 7 steps:
      loss 10.55 -> 5.04 (bouncy), generates "Unsloth! ..."
  * HEAD + PR #663 only (bias_correction=True), 7 steps:
      loss 10.55 -> 0.17 (smooth), generates "5 lbs!"
  * HEAD + bias_correction=False (this PR), 7 steps:
      loss 10.55 -> 2.44 (bouncy), generates "Unsloth! ..."

The upstream MLX smoke test in unslothai/unsloth and every other
existing MLX fine-tune script implicitly relied on the bias_correction=
False default. Restoring it as the default fixes that contract.

Add `adam_bias_correction: bool = False` to MLXTrainingConfig so users
who want true HF/torch.AdamW parity can opt in explicitly. Plumb it
through both the adamw and adam construction paths.

Regression test pins the default to False.
@danielhanchen

Member Author

Update: just pushed a second commit (72a448b) that ALSO restores adam_bias_correction=False as the MLX default. Empirically, the max_grad_value change alone was necessary but not sufficient — bisection on danielhanchen/unsloth-staging-2#119 shows the smoke test only re-greens when both knobs revert.

| unsloth-zoo | bias_correction | step 1 -> 7 (loss) | post_loss | generation |
| --- | --- | --- | --- | --- |
| pre-#634 (f37d510) | False | 10.55 -> 5.04 (bouncy: 4=8.7) | 1.48 | " Unsloth! ... My name is Unsloth! ..." |
| HEAD post-#634 | True | 9.69 -> 0.43 (smooth) | 0.236 | "льнастему!" |
| PR #663 max_grad_value only | True | 10.55 -> 0.17 (smooth) | 0.009 | "5 lbs!" |
| PR #663 + bias_correction=False | False | 10.55 -> 2.44 (bouncy: 4=8.7) | 1.46 | " Unsloth! My name is Unsloth! ..." |

Same bouncy curve as pre-#634, same memorization basin, same generation. PR #663 now restores the full pre-#634 default behavior with both new opt-in fields documented.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 72a448b360

Comment thread unsloth_zoo/mlx/trainer.py Outdated
# and is what existing MLX fine-tune scripts (including the smoke
# test in unslothai/unsloth) were tuned against. Default False to
# preserve that contract; pass True to opt in to HF/torch parity.
adam_bias_correction: bool = False

P1: Keep default Adam bias correction consistent

With this default, the existing optimizer contract test tests/test_pr_a_imports.py::test_adam_optimizers_enable_bias_correction now fails for both MLXTrainingConfig(optim="adamw") and optim="adam" because _build_optimizer() passes bias_correction=False. I verified the targeted test failure locally; if disabling bias correction by default is intentional, the existing test/contract needs to be updated in the same change, otherwise the suite remains red.

@danielhanchen

Member Author

Update: just reverted commit 72a448b (the bias_correction flip). The Mac-CI parity workflow at danielhanchen/unsloth-staging-2#119 (probes 17a-i) showed that flip was wrong.

Empirical sweep at unsloth-zoo HEAD with bias_correction=True (the post-#634 default), varying max_steps and seed:

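| steps | seed | post_train_loss | "Unsloth" in greedy? |
| --- | --- | --- | --- |
| 7 | 3407 | 0.009 | no |
| 15 | 3407 | 0.0 | yes |
| 20 | 3407 | 0.0 | yes |
| 30 | 3407 | 0.0 | yes |
| 30 | 42 | 0.0 | yes |
| 30 | 999 | 0.0 | yes |
| 30 | 1337 | 0.0 | yes |
| 50 | 3407 | 0.0 | no (overshoots past basin) |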

So bias_correction=True is mathematically correct (HF/torch parity) AND empirically converges to the memorization basin reliably across 4 seeds when given ~15-30 steps. PR #634's flip from MLX's False default to torch's True was the right call.

The smoke test's failure is a max_steps=7 issue, not a trainer bug. With bias_correction=False (my abandoned fix), 30-step runs at unsloth-zoo HEAD had a post-train loss of 2.25 and emitted gibberish, which is WORSE than HEAD with bc=True. The pre-#634 "bouncy loss curve that happened to land in the basin at step 7" was unstable.

PR #663 now contains only the max_grad_value=None change, which is still right: silently overriding a user-supplied max_grad_norm is the real HF-parity break.

Followup recommendation for unslothai/unsloth#5498: bump tests/studio/run_real_mlx_smoke.py's max_steps from 7 to ~20. Optionally also swap the brittle "EXPECT 'Unsloth' in greedy" gate for a post_train_loss < 0.1 numeric gate.
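
A minimal sketch of what such a numeric gate could look like. The function name and signature are hypothetical and not taken from run_real_mlx_smoke.py; only the 0.1 threshold comes from the recommendation above.

```python
import math

def smoke_gate(post_train_loss: float, threshold: float = 0.1) -> bool:
    """Pass only when the post-train loss is finite and below the threshold."""
    # math.isfinite rejects NaN and inf, which a bare `<` comparison would
    # also fail on, but checking explicitly keeps the intent readable.
    return math.isfinite(post_train_loss) and post_train_loss < threshold

print(smoke_gate(0.009))  # loss from the 7-step run in the table above
print(smoke_gate(2.25))   # loss from the 30-step bias_correction=False run
```

A loss gate of this kind does not depend on which exact tokens greedy decoding happens to emit, so it avoids the string-match brittleness.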

…act)

Reverts commit 18596f2. The original 72a448b adam_bias_correction
exposure was reverted on the (premature) conclusion that bc=True at
30 steps with seed=3407 produced a working memorization. Subsequent
parity probing (rounds E-G of mlx-parity-probes) showed:

  * the bc=True trainer converges to post_train_loss ~0 across all
    seeds tested (3407, 42, 999, 1337, 7777, 12345) and a wide LR
    band (5e-4 - 2e-3) -- training is healthy;
  * BUT the post-train greedy-decode test ("does the model emit
    'Unsloth!' from the prompt?") is non-monotonic across (steps,
    seed) under bc=True:
      seed=3407: 30 OK, 50 BAD, 60 BAD, 100 OK
      seed=42:   30 OK, 60 BAD
      seed=12345: 30 BAD, 40 BAD, 50 OK, 60 OK
      etc.
  * mlx-lm's native LoRA at the same iter counts barely converges
    (last_loss 3-5 with mlx-lm defaults), so it doesn't reach the
    over-memorized basin our trainer does -- the fragility is a
    side effect of fast/aggressive memorization on a tiny single-
    row fixture, not a trainer bug.

Given the smoke fixture is so brittle to (steps, seed), users will
benefit from being able to flip bias_correction back to its MLX-
ecosystem-native False (which is what mlx.optimizers.AdamW,
mlx_lm.lora and every existing MLX fine-tune script use) without
also having to fork the trainer.

This commit:
  - re-adds `adam_bias_correction: bool = False` to MLXTrainingConfig
  - plumbs it through both the adamw and adam construction paths
  - default = False (MLX framework default, mlx-lm default, pre-#634)
  - users who want torch.optim.AdamW parity can opt in explicitly.
Round J of the mlx-parity-probes workflow tested bc=False vs bc=True
end-to-end on the single-row LoRA smoke fixture (PR-663 head, all
5 jobs reported adam_bc_field_supported=True so the field is wired):

  bc=True,  30 steps, seed=3407: post_loss=0.0000 generates Unsloth
  bc=False, 30 steps, seed=3407: post_loss=2.25   gibberish
  bc=False, 30 steps, seed=12345: post_loss=8.20  gibberish
  bc=False, 60 steps, seed=3407: post_loss=3.04   "is is is..."
  bc=False, 60 steps, seed=42:   post_loss=1.88   " name is"

And round I, mlx-lm NATIVE LoRA (which uses bias_correction=False
internally) on the same fixture:

  200 iters seed=3407: post_loss=2.95 gibberish
  500 iters seed=3407: post_loss=2.45 gibberish
  500 iters seed=12345: post_loss=0.56 "sloth!" (lowercase miss)

bc=False simply cannot memorize this fixture in a smoke-relevant
budget. The earlier rationale that bc=False is the "MLX-ecosystem
historical default and every existing fine-tune script was tuned
against it" was correct for the field name's heritage, but the
upstream smoke test (and any user iterating quickly on a tiny
fixture) needs the larger early-step updates that bc=True provides.

Flip the default. Field stays opt-out (adam_bias_correction=False)
for users running long-horizon fine-tunes that depended on the
old MLX-framework default behavior.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ef003aae52

# early-step behavior every existing MLX fine-tune script (incl.
# the upstream smoke test) was tuned against. See dataclass field
# for the full HF-parity tradeoff.
bc = bool(getattr(self.args, "adam_bias_correction", False))

P2: Preserve bias correction for non-MLX args

When callers pass a TrainingArguments-like/custom args object that does not yet define adam_bias_correction, this fallback now disables AdamW bias correction even though the previous trainer behavior and the new MLXTrainingConfig default are both True; the same False fallback is repeated for the adam branch below. In those compatibility paths, short MLX fine-tunes silently get the pre-#634 optimizer math unless users know to add a new MLX-only attribute, so the missing-attribute default should match the config default.
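
One way to keep the two defaults from drifting apart is to source both from a single constant. This is a sketch under assumptions: the constant, the stand-in dataclass, and the helper are hypothetical names, not the trainer's actual code.

```python
from dataclasses import dataclass
from types import SimpleNamespace

# Single source of truth for the default, shared by the config and the fallback.
DEFAULT_ADAM_BIAS_CORRECTION = True

@dataclass
class ConfigSketch:
    # Hypothetical stand-in for MLXTrainingConfig.
    adam_bias_correction: bool = DEFAULT_ADAM_BIAS_CORRECTION

def resolve_bias_correction(args) -> bool:
    """Read the flag, falling back to the config default when it is missing."""
    return bool(getattr(args, "adam_bias_correction", DEFAULT_ADAM_BIAS_CORRECTION))

# A custom args object that predates the new field.
legacy_args = SimpleNamespace(learning_rate=1e-3)
print(resolve_bias_correction(legacy_args))     # same as the config default
print(resolve_bias_correction(ConfigSketch()))  # config default
```

With this shape, args objects lacking the attribute get the same optimizer math as a default config, and an explicit `adam_bias_correction=False` still opts out.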

Round L of mlx-parity-probes located the divergence boundary for
adam_bias_correction=False on the single-row LoRA smoke fixture:
  50 steps:  post_train_loss=5.06 (high but finite)
  100 steps: post_train_loss=NaN (catastrophic divergence)

So the field isn't just "slow vs fast" -- bc=False is dangerous at
long horizons on a tiny / fast-overfitting fixture. Update the
docstring so users opt in to False with eyes open instead of
treating it as a safer "MLX-ecosystem default" option.
Round Q of the mlx-parity-probes workflow scanned the LR axis at
long horizons and revealed that the NaN-divergence is governed by
the lr * (1 - bias_correction_at_low_t) interaction:

  lr=1e-3, bc=True  : stable 30..1000 steps (smoke + long runs)
  lr=1e-3, bc=False : NaN past ~88 steps on small fixtures
  lr=1e-4, bc=False : stable through 200 steps (memorizes)
  lr=5e-3, bc=True  : NaN by ~100 steps

Tighten the docstring so users with small LRs (1e-4 or lower) are
not warned off bc=False, and so users with aggressive LRs (5e-3+)
know bc=True doesn't save them either.
Earlier revision changed the default to None on HF/TRL-parity
reasoning. That reasoning is wrong for MLX: max_grad_norm requires
cross-tree reduction and materializing all grad tensors at full
precision, while max_grad_value is a tree_map(mx.clip) on each
leaf with no reduction. The elementwise clip is the MLX-native
choice, both faster and noticeably more memory-friendly.
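
The structural difference can be sketched on a nested gradient tree in pure Python. The tree helpers below are simplified stand-ins written for this example (they are not the MLX utilities), and the parameter names are made up.

```python
import math

def tree_map(fn, tree):
    """Apply fn to every scalar in a nested dict whose leaves are lists."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    return [fn(x) for x in tree]

def tree_leaves(tree):
    """Yield every leaf list in the tree."""
    if isinstance(tree, dict):
        for v in tree.values():
            yield from tree_leaves(v)
    else:
        yield tree

def clip_value(tree, max_value):
    # Single pass: each leaf is clamped on its own, no cross-leaf reduction.
    return tree_map(lambda x: max(-max_value, min(max_value, x)), tree)

def clip_norm(tree, max_norm):
    # Pass 1: reduce over ALL leaves to obtain one global norm.
    total = math.sqrt(sum(x * x for leaf in tree_leaves(tree) for x in leaf))
    # Pass 2: rescale every leaf by the same factor.
    scale = min(1.0, max_norm / (total + 1e-6))
    return tree_map(lambda x: x * scale, tree), total

grads = {"layer0": {"lora_a": [3.0, 0.0], "lora_b": [0.0, 4.0]}}
value_clipped = clip_value(grads, 1.0)
norm_clipped, total_norm = clip_norm(grads, 1.0)
print(value_clipped, total_norm)
```

The norm path cannot rescale any leaf until every leaf has contributed to the reduction, which is the dependency the commit message refers to.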

Empirically (47-round, 13-seed sweep of the upstream smoke
fixture):
  value=0.5 : 10/13 ✓
  value=1.0 :  8/13 ✓  <- default, matches universal clip-1 baseline
  norm=1.0  :  6/13 ✓
  value=5.0 :  4/13 ✓  (PR #634's old default; ineffective)

So the cheaper default is also the higher-pass-rate default --
no tradeoff. PR #634's actual bug was the SILENT override of a
user-supplied max_grad_norm (still in this PR's notice path),
plus the hardcoded bias_correction=True (now properly opted-in
via the adam_bias_correction field).

Test pins the default and the explicit-None opt-out behavior.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: aed74d98be

Comment on lines +765 to +766
_raw_mgv = getattr(args, "max_grad_value", None)
max_grad_value = 0.0 if _raw_mgv is None else float(_raw_mgv or 0.0)

P1: Honor max_grad_norm for the default config path

This only turns value clipping off for args objects that lack max_grad_value; MLXTrainingConfig still defines max_grad_value=1.0, so MLXTrainingConfig(max_grad_norm=1.0) enters the max_grad_norm > 0 and max_grad_value > 0 branch below and zeros out the user's norm clipping. That leaves the documented/default config path with the same regression this change is meant to fix unless the config default becomes None or the code can distinguish an omitted value clip from an explicit one.

@danielhanchen

Member Author

Superseded by #671 (merged as 6efe9ac). The redesign keeps the cheap MLX default (max_grad_value=1.0) when neither knob is user-set, while still honoring user-supplied max_grad_norm and printing an override notice when both are explicitly set. Closing as the issue this PR addressed is now fixed on main.
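
The behavior described in this comment can be sketched as follows. This is not the code merged in #671: the function, the use of None to mean "not set by the user", and the choice that the value clip wins when both are set are assumptions made for illustration.

```python
from typing import Optional, Tuple

DEFAULT_MAX_GRAD_VALUE = 1.0  # the cheap MLX default named in the comment above

def resolve(user_norm: Optional[float] = None,
            user_value: Optional[float] = None) -> Tuple[float, float, Optional[str]]:
    """Return (max_grad_norm, max_grad_value, notice). None means 'not user-set'."""
    norm = float(user_norm) if user_norm is not None else 0.0
    if user_value is None:
        # Apply the default value clip only when the user set neither knob.
        value = DEFAULT_MAX_GRAD_VALUE if user_norm is None else 0.0
    else:
        value = float(user_value)
    notice = None
    if norm > 0 and value > 0:
        # Assumption: the value clip takes precedence, as in the original notice.
        notice = "both clips set explicitly; using max_grad_value, ignoring max_grad_norm"
        norm = 0.0
    return norm, value, notice

print(resolve())                               # neither set: default value clip
print(resolve(user_norm=1.0))                  # user norm honored, no value clip
print(resolve(user_norm=1.0, user_value=0.5))  # both set: notice is returned
```

Tracking "unset" separately from an explicit number is what lets the default stay cheap without overriding a user-supplied max_grad_norm.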

@danielhanchen danielhanchen deleted the fix-mlx-grad-clip-hf-parity branch May 19, 2026 13:00
1 participant